'''
Author: SlytherinGe
LastEditTime: 2021-02-24 13:36:06
'''
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.backends.cudnn as cudnn

from models import create_resnetX_from_pretrained
from data import *
'''
param:cfg{
    model{
        pretrained_path,
        resnet_type,
        num_classes,
        local_root,
    }
    dataset{
        train_data_root,
        val_data_root,
        train_transform,
        val_transform,
    }
    hyperparam{
        batch_size,
        num_workers,
        lr,
        momentum,
        epoch,
        steps,
        lr_schedule,
        gamma,
    }
}
'''

def get_scheduler(optimizer, cfg):
    """Build the learning-rate scheduler selected by cfg['hyperparam']['lr_schedule'].

    Supported values are 'cosine', 'multisteps' and 'exponential'; any other
    value yields None (callers must be prepared for that).
    """
    hp = cfg['hyperparam']
    schedule = hp['lr_schedule']
    if schedule == 'cosine':
        # anneal over (epoch - 1) steps down to lr / 1000
        return optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                    hp['epoch'] - 1,
                                                    hp['lr'] * 0.001)
    if schedule == 'multisteps':
        return optim.lr_scheduler.MultiStepLR(optimizer,
                                              hp['steps'],
                                              hp['gamma'])
    if schedule == 'exponential':
        return optim.lr_scheduler.ExponentialLR(optimizer,
                                                hp['gamma'])
    return None

def train_model(cfg):
    """Train a ResNet classifier on the SAR ship dataset described by *cfg*.

    cfg is the nested dict documented at the top of this file (model /
    dataset / hyperparam sections). Trains for cfg['hyperparam']['epoch']
    epochs on GPU, saves a checkpoint after every epoch under
    cfg['model']['local_root']/tmp/, and evaluates on the validation set
    after each epoch, reporting the best epoch at the end.
    """
    # create datasets
    train_dataset = STJUSarShipDataset(cfg['dataset']['train_data_root'], cfg['dataset']['train_transform'])
    val_dataset = STJUSarShipDataset(cfg['dataset']['val_data_root'], cfg['dataset']['val_transform'])
    # create model (optionally from pretrained weights)
    model_resnet = create_resnetX_from_pretrained(cfg['model']['resnet_type'],
            cfg['model']['num_classes'], cfg['model']['pretrained_path'])
    net = torch.nn.DataParallel(model_resnet)
    cudnn.benchmark = True
    # create optimizer and loss
    optimizer = optim.SGD(net.parameters(), cfg['hyperparam']['lr'],
                          cfg['hyperparam']['momentum'],
                          weight_decay=5e-4)
    loss_fn = nn.CrossEntropyLoss(reduction='mean')
    # init dataloaders
    train_loader = data.DataLoader(train_dataset, cfg['hyperparam']['batch_size'],
                                   num_workers=cfg['hyperparam']['num_workers'],
                                   shuffle=True,
                                   pin_memory=True)
    # FIX: the third positional DataLoader argument is `shuffle`, so the
    # original `DataLoader(val_dataset, 1, 1)` silently shuffled validation
    # data; use keyword arguments to set batch_size and num_workers instead.
    val_loader = data.DataLoader(val_dataset, batch_size=1, num_workers=1)

    EPOCH = cfg['hyperparam']['epoch']
    scheduler = get_scheduler(optimizer, cfg)
    correct_cnt, pred_cnt = 0, 0
    iter_per_epoch = len(train_dataset) // cfg['hyperparam']['batch_size']
    best_epoch, best_epoch_accuracy = 1, 0
    net.train()
    print('training start!')
    for epoch in range(EPOCH):
        iteration = 0   # iter count within this epoch
        total_loss, total_accuracy = 0.0, 0.0
        for images, targets in train_loader:
            images = images.cuda()
            targets = targets.cuda()
            # forward
            out = net(images)
            # backprop
            optimizer.zero_grad()
            # targets are assumed one-hot; CrossEntropyLoss needs class indices
            loss = loss_fn(out, targets.max(dim=1)[1])
            # FIX: accumulate a Python float, not the loss tensor — keeping
            # the tensor retains each iteration's autograd graph all epoch.
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            iteration += 1
            # running accuracy over the last <=10 iterations
            pred, correct = calc_num_correct_one_batch(targets, out)
            pred_cnt += pred
            correct_cnt += correct
            total_accuracy += float(correct_cnt) / pred_cnt
            if iteration % 10 == 0:
                print("iter:[{:d}/{:d}] epoch:[{:d}/{:d}] lr:{} loss:{:.4f} accuracy:{:.4f}".format(
                        iteration, iter_per_epoch, epoch+1, EPOCH,
                        # FIX: get_lr() is deprecated and can report a scaled
                        # value; get_last_lr() is the documented accessor.
                        scheduler.get_last_lr(),
                        loss.item(), float(correct_cnt) / pred_cnt
                ))
                pred_cnt, correct_cnt = 0, 0
        if iteration:  # guard against an empty train loader
            print('average loss is: {:.3f}  average accuracy is: {:.3f}'.format(
                total_loss / iteration, total_accuracy / iteration))
        print('saving model...')
        torch.save(model_resnet.state_dict(),
                   os.path.join(cfg['model']['local_root'], 'tmp/',
                                '{}_{}.pth'.format(cfg['model']['resnet_type'], epoch+1)))
        print('starting to evaluate model...')
        eval_accuracy = eval_model_during_training(net, val_loader, len(val_dataset))
        if eval_accuracy > best_epoch_accuracy:
            best_epoch = epoch + 1
            best_epoch_accuracy = eval_accuracy
        print("epoch [{:d}] finished, best epoch is epoch [{:d}] so far".format(epoch + 1, best_epoch))
        scheduler.step()
    print('training done! best epoch during training is epoch [{:d}], best accuracy is {:.4f}'.format(best_epoch, best_epoch_accuracy * 100.0))
        

def eval_model_during_training(net, testloader, test_size):
    """Evaluate *net* on *testloader* and return accuracy as a float in [0, 1].

    Temporarily switches the network to eval mode (gradients disabled) and
    restores train mode before returning. *test_size* is only used for the
    progress printout; the percentage assumes a batch size of 1 — confirm
    against the caller if the loader batches differently.
    """
    pred_cnt, corr_cnt = 0, 0
    iteration = 0
    net.eval()
    with torch.no_grad():
        for images, targets in testloader:
            out = net(images.cuda())
            pred, correct = calc_num_correct_one_batch(targets.cuda(), out)
            pred_cnt += pred
            corr_cnt += correct
            iteration += 1
            if iteration % 100 == 0:
                print('eval [{:.3f}%/100%]'.format(iteration / test_size * 100.0))
        # FIX: guard against an empty loader, which would otherwise raise
        # ZeroDivisionError here.
        accuracy = float(corr_cnt) / pred_cnt if pred_cnt else 0.0
        print('eval result: accuracy:{:.4f}%'.format(accuracy * 100.0))
    net.train()
    return accuracy



def adjust_learning_rate(initial_lr, optimizer, gamma, step):
    """Set every param group's lr to ``initial_lr * gamma ** step``.

    Manual step-decay schedule, adapted from the PyTorch ImageNet example:
    https://github.com/pytorch/examples/blob/master/imagenet/main.py
    (Currently unused: train_model relies on get_scheduler instead.)
    """
    decayed_lr = initial_lr * gamma ** step
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr

def calc_num_correct_one_batch(label, predict):
    """Count correct predictions in one batch.

    label   : one-hot (or score) tensor of shape (batch, num_classes);
              the class is taken as the argmax along dim 1.
    predict : raw logits of shape (batch, num_classes).
    Returns (batch_size, num_correct).
    """
    # FIX: drop the redundant nn.Softmax pass — softmax is monotonic, so the
    # argmax of the logits is identical and the extra op is wasted work.
    pred_index = predict.max(dim=1)[1]
    label_index = label.max(dim=1)[1]
    # sum of the boolean equality mask is the number of correct predictions
    correct = (pred_index == label_index).sum().item()
    return label.size(0), correct


if __name__ == '__main__':

    # Training configuration; the schema is documented at the top of the file.
    cfg = {
        'model': {
            'pretrained_path': None,          # None -> train from scratch
            'resnet_type': 'resnet101',
            'num_classes': 16,
            # checkpoints are written under <local_root>/tmp/ by train_model
            'local_root': r'D:\develop\pytorch-resnet-for-huawei-modelarts\save',
        },
        'dataset':{
            'train_data_root': r'D:\develop\Deep Learning\DataSet\SJTU\OpenSARShip Total\OpenSARShip_total',
            'val_data_root': r'D:\develop\Deep Learning\DataSet\SJTU\Sentinel-ship-45',
            # 224x224 inputs; identity normalization (zero mean, unit std)
            'train_transform': PreprocessTransform(224,
                                        rgb_means=(0,0,0),
                                        rgb_std=(1,1,1)),
            'val_transform' : BaseTransform(224,
                                        rgb_means=(0,0,0),
                                        rgb_std=(1,1,1)),
        },
        'hyperparam':{
            'batch_size': 16,
            'num_workers': 8,
            'lr': 1e-4,
            'momentum': 0.9,
            'epoch': 20,
            'steps': None,                    # only used by 'multisteps' schedule
            'lr_schedule' : 'cosine',         # see get_scheduler for options
            'gamma' : 0,                      # only used by 'multisteps'/'exponential'
        }
    }
    train_model(cfg)

